/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Helper functions for mapping CUTLASS concepts to cuDNN.

*/

#pragma once
#if CUTLASS_ENABLE_CUDNN
#include <cuda_runtime.h>
#include <cudnn.h>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/library/library.h"
#include "enumerated_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Converts a cuDNN status code to the corresponding cutlass::Status.
/// The dispatcher constructors in this header use it to record the outcome of
/// each cuDNN call in their `status` member.
Status get_cutlass_status(cudnnStatus_t cudnn_status);

/// Converts a cuDNN status code to a cutlass::profiler::Disposition.
Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status);

/// Converts a cudnnStatus_t to a cutlass::Status and returns it when it is
/// Status::kSuccess; otherwise throws an exception.
/// NOTE(review): the definition is not part of this header, so the exception
/// type is not visible here -- confirm before writing a catch clause.
Status checkCudnnErr(cudnnStatus_t cudnn_status);

/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration.
/// The result is written to `cudnn_conv_mode`. The dispatcher constructors
/// treat a `false` return as "no cuDNN equivalent" (Status::kErrorNotSupported).
bool get_cudnn_conv_mode(cudnnConvolutionMode_t& cudnn_conv_mode,
                         conv::Mode conv_mode);

/// Maps a CUTLASS layout type to a cuDNN tensor format enumeration.
/// The result is written to `cudnn_layout`. The dispatcher constructors treat
/// a `false` return as "no cuDNN equivalent" (Status::kErrorNotSupported).
bool get_cudnn_layout(cudnnTensorFormat_t& cudnn_layout,
                      library::LayoutTypeID layout);

/// Maps a CUTLASS numeric type to a cuDNN data type enumeration.
/// The result is written to `cudnn_element_type`. The dispatcher constructors
/// treat a `false` return as "no cuDNN equivalent"
/// (Status::kErrorNotSupported).
bool get_cudnn_datatype(cudnnDataType_t& cudnn_element_type,
                        library::NumericTypeID element_type);

/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type.
/// The result is written to `cudnn_math_type`; the inputs are taken from
/// `conv_desc`.
bool get_cudnn_mathtype(cudnnMathType_t& cudnn_math_type,
                        library::ConvDescription const& conv_desc);

/// Returns a Status telling whether cuDNN can satisfy a particular Conv2d
/// description and configuration.
Status cudnn_satisfies(library::ConvDescription const& desc,
                       library::Conv2dConfiguration const& configuration);

/// Returns a Status telling whether cuDNN can satisfy a particular Conv3d
/// description and configuration.
Status cudnn_satisfies(library::ConvDescription const& desc,
                       library::Conv3dConfiguration const& configuration);

/// Cudnn compute type seems to be hardcoded to float (to handle a possible
/// cudnn issue), so scalars are handed to cuDNN as float.
/// Presumably reads the scalar at `src`, interpreted as `type`, and returns
/// it converted to float -- the definition is not visible in this header.
float cast_cudnn_compute_type_to_float(library::NumericTypeID type,
                                       void const* src);

/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate
/// object creation and to destroy cudnnHandle_t on CudnnCreate object
/// destruction. Additionaly, it provides implicit cast from CudnnCreate's
/// object to cudnnHandle_t's object
class CudnnCreate {
private:
    cudnnHandle_t handle;
    cudnnStatus_t status;

public:
    CudnnCreate() { status = cudnnCreate(&handle); }

    ~CudnnCreate() { cudnnDestroy(handle); }

    /// Implicit cast CudnnCreate object to cudnnHandle_t
    operator cudnnHandle_t() const { return handle; }

    /// returns cudnnStatus_t for handle creation
    cudnnStatus_t get_cudnn_create_status() { return status; }
};

namespace detail {

/// Dispatcher to cudnn convolution operators
struct cudnnConvDispatcher {
    //
    // Data members
    //
    // library::Conv2dConfiguration configuration;

    // Copy of the caller's tensor pointers and scalars, and the kind of
    // convolution (fprop, dgrad, wgrad) this dispatcher runs.
    library::ConvArguments arguments;
    library::ConvKind conv_kind;

    // cudnn-specific data structures to fill cudnn API call arguments
    // cudnn activation, filter, and output descriptors.
    // They start out null so that a constructor which bails out early leaves
    // well-defined (null) handles behind instead of indeterminate pointers.
    // NOTE(review): nothing in this struct calls the matching cudnnDestroy*
    // functions, so created descriptors live until process exit -- confirm
    // this is intended before adding a destructor (the struct is copyable).
    cudnnTensorDescriptor_t activation_desc = nullptr;
    cudnnFilterDescriptor_t filter_desc = nullptr;
    cudnnTensorDescriptor_t output_desc = nullptr;
    cudnnConvolutionDescriptor_t conv_desc = nullptr;

    // cudnn datatypes
    cudnnDataType_t data_type_activation;
    cudnnDataType_t data_type_filter;
    cudnnDataType_t data_type_output;

    // cudnn layouts
    cudnnTensorFormat_t layout_activation;
    cudnnTensorFormat_t layout_filter;
    cudnnTensorFormat_t layout_output;

    // cudnn convolution mode
    cudnnConvolutionMode_t conv_mode;

    // cudnn math type (tensorop, tensorop with conversion, simt)
    cudnnMathType_t math_type;

    // cudnn compute data type
    cudnnDataType_t compute_type;

    // cudnn compute type seems to be hardcoded to float (to handle a possible
    // cudnn issue). Zero is only a placeholder until a constructor converts
    // the caller's scalars.
    float alpha = 0.0f;
    float beta = 0.0f;

    // cudnn workspace: size reported by the cudnn workspace query and the
    // device buffer of that size
    size_t workspace_size_in_bytes = 0;
    cutlass::device_memory::allocation<char> workspace;

    // select cudnn's implicit gemm precomputed algorithm with tensor operations
    // (the same fixed algorithms are used for the workspace query and for
    // execution)
    static cudnnConvolutionFwdAlgo_t const fprop_algo =
            CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    static cudnnConvolutionBwdDataAlgo_t const dgrad_algo =
            CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
    static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo =
            CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;

    // Outcome of construction; callers are expected to inspect it before
    // invoking the dispatcher.
    Status status;

    //
    // Methods
    //

    // TODO: unify the conv2d and conv3d cudnnConvDispatcher ctors by unifying
    // Conv2dConfiguration and Conv3dConfiguration

    // ctor for conv2d
    //
    // Translates the CUTLASS operation description into cuDNN enumerations,
    // creates and configures the convolution/tensor/filter descriptors, queries
    // the workspace size for the fixed algorithm of `conv_kind`, and allocates
    // that workspace.
    //
    // The result is reported through `status`:
    //   - Status::kErrorNotSupported when the description has no cuDNN
    //     equivalent (or `conv_kind` is not fprop/dgrad/wgrad);
    //   - the converted status of the FIRST failing cuDNN call otherwise;
    //   - Status::kSuccess when everything succeeded.
    // Construction stops at the first failure so that a later successful call
    // cannot mask an earlier error.
    cudnnConvDispatcher(library::ConvDescription const& op_desc,
                        library::Conv2dConfiguration configuration,
                        library::ConvArguments arguments_, cudnnHandle_t handle)
            :  // configuration(configuration_),
              arguments(arguments_),
              conv_kind(op_desc.conv_kind),
              status(Status::kSuccess) {
        // Records the outcome of one cuDNN call in `status`; true on failure.
        auto failed = [this](cudnnStatus_t cudnn_status) -> bool {
            status = get_cutlass_status(cudnn_status);
            return status != Status::kSuccess;
        };

        // Get cudnn datatype, layout, convolution mode, math type and compute
        // type from library::ConvDescription. Evaluation stops at the first
        // mapping that does not exist.
        bool const supported =
                get_cudnn_datatype(data_type_activation, op_desc.A.element) &&
                get_cudnn_datatype(data_type_filter, op_desc.B.element) &&
                get_cudnn_datatype(data_type_output, op_desc.C.element) &&
                get_cudnn_layout(layout_activation, op_desc.A.layout) &&
                get_cudnn_layout(layout_filter, op_desc.B.layout) &&
                get_cudnn_layout(layout_output, op_desc.C.layout) &&
                get_cudnn_conv_mode(conv_mode,
                                    configuration.problem_size.mode) &&
                get_cudnn_mathtype(math_type, op_desc) &&
                get_cudnn_datatype(compute_type,
                                   op_desc.tile_description.math_instruction
                                           .element_accumulator);

        // Check cutlass Conv2d description has equivalent operator in cudnn
        if (!supported) {
            status = Status::kErrorNotSupported;
            return;
        }

        // cudnn compute type seems to be hardcoded to float (to handle a
        // possible cudnn issue)
        alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue,
                                                 arguments.alpha);
        beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue,
                                                arguments.beta);

        // Create convolution descriptor object
        if (failed(cudnnCreateConvolutionDescriptor(&conv_desc))) {
            return;
        }

        // Configure convolution operator (one entry per spatial dim: h, w)
        int const padding[] = {configuration.problem_size.pad_h,
                               configuration.problem_size.pad_w};
        int const stride[] = {configuration.problem_size.stride_h,
                              configuration.problem_size.stride_w};
        int const dilation[] = {configuration.problem_size.dilation_h,
                                configuration.problem_size.dilation_w};

        if (failed(cudnnSetConvolutionNdDescriptor(
                    conv_desc, op_desc.conv_dim, padding, stride, dilation,
                    conv_mode, compute_type))) {
            return;
        }

        // Set groups
        if (failed(cudnnSetConvolutionGroupCount(
                    conv_desc, configuration.problem_size.groups))) {
            return;
        }

        // Create activation, filter, and output descriptor objects
        if (failed(cudnnCreateTensorDescriptor(&activation_desc))) {
            return;
        }
        if (failed(cudnnCreateFilterDescriptor(&filter_desc))) {
            return;
        }
        if (failed(cudnnCreateTensorDescriptor(&output_desc))) {
            return;
        }

        // Set activation (N, C, H, W), filter (K, C, R, S), and output
        // (N, K, P, Q) descriptors
        if (failed(cudnnSetTensor4dDescriptor(
                    activation_desc, layout_activation, data_type_activation,
                    configuration.problem_size.N, configuration.problem_size.C,
                    configuration.problem_size.H,
                    configuration.problem_size.W))) {
            return;
        }

        if (failed(cudnnSetFilter4dDescriptor(
                    filter_desc, data_type_filter, layout_filter,
                    configuration.problem_size.K, configuration.problem_size.C,
                    configuration.problem_size.R,
                    configuration.problem_size.S))) {
            return;
        }

        if (failed(cudnnSetTensor4dDescriptor(
                    output_desc, layout_output, data_type_output,
                    configuration.problem_size.N, configuration.problem_size.K,
                    configuration.problem_size.P,
                    configuration.problem_size.Q))) {
            return;
        }

        // Set math instruction to tensor op
        if (failed(cudnnSetConvolutionMathType(conv_desc, math_type))) {
            return;
        }

        // Initialize workspace: query the size required by the fixed algorithm
        // of the requested convolution kind
        switch (conv_kind) {
            case library::ConvKind::kFprop:
                status = get_cutlass_status(
                        cudnnGetConvolutionForwardWorkspaceSize(
                                handle, activation_desc, filter_desc, conv_desc,
                                output_desc, fprop_algo,
                                &workspace_size_in_bytes));
                break;
            case library::ConvKind::kDgrad:
                status = get_cutlass_status(
                        cudnnGetConvolutionBackwardDataWorkspaceSize(
                                handle, filter_desc, output_desc, conv_desc,
                                activation_desc, dgrad_algo,
                                &workspace_size_in_bytes));
                break;
            case library::ConvKind::kWgrad:
                status = get_cutlass_status(
                        cudnnGetConvolutionBackwardFilterWorkspaceSize(
                                handle, activation_desc, output_desc, conv_desc,
                                filter_desc, wgrad_algo,
                                &workspace_size_in_bytes));
                break;
            default:
                // operator() cannot execute any other kind either
                status = Status::kErrorNotSupported;
                break;
        }

        if (status != Status::kSuccess) {
            return;
        }

        workspace = cutlass::device_memory::allocation<char>(
                workspace_size_in_bytes);
    }

    // ctor for conv3d
    //
    // Translates the CUTLASS operation description into cuDNN enumerations,
    // creates and configures the convolution/tensor/filter descriptors, queries
    // the workspace size for the fixed algorithm of `conv_kind`, and allocates
    // that workspace.
    //
    // The result is reported through `status`:
    //   - Status::kErrorNotSupported when the description has no cuDNN
    //     equivalent (or `conv_kind` is not fprop/dgrad/wgrad);
    //   - the converted status of the FIRST failing cuDNN call otherwise;
    //   - Status::kSuccess when everything succeeded.
    // Construction stops at the first failure so that a later successful call
    // cannot mask an earlier error.
    cudnnConvDispatcher(library::ConvDescription const& op_desc,
                        library::Conv3dConfiguration configuration,
                        library::ConvArguments arguments_, cudnnHandle_t handle)
            :  // configuration(configuration_),
              arguments(arguments_),
              conv_kind(op_desc.conv_kind),
              status(Status::kSuccess) {
        // Records the outcome of one cuDNN call in `status`; true on failure.
        auto failed = [this](cudnnStatus_t cudnn_status) -> bool {
            status = get_cutlass_status(cudnn_status);
            return status != Status::kSuccess;
        };

        // Get cudnn datatype, layout, convolution mode, math type and compute
        // type from library::ConvDescription. Evaluation stops at the first
        // mapping that does not exist. The math type must be resolved here
        // because it is handed to cudnnSetConvolutionMathType further down.
        bool const supported =
                get_cudnn_datatype(data_type_activation, op_desc.A.element) &&
                get_cudnn_datatype(data_type_filter, op_desc.B.element) &&
                get_cudnn_datatype(data_type_output, op_desc.C.element) &&
                get_cudnn_layout(layout_activation, op_desc.A.layout) &&
                get_cudnn_layout(layout_filter, op_desc.B.layout) &&
                get_cudnn_layout(layout_output, op_desc.C.layout) &&
                get_cudnn_conv_mode(conv_mode,
                                    configuration.problem_size.mode) &&
                get_cudnn_mathtype(math_type, op_desc) &&
                get_cudnn_datatype(compute_type,
                                   op_desc.tile_description.math_instruction
                                           .element_accumulator);

        // Check cutlass Conv3d description has equivalent operator in cudnn
        if (!supported) {
            status = Status::kErrorNotSupported;
            return;
        }

        // cudnn compute type seems to be hardcoded to float (to handle a
        // possible cudnn issue)
        alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue,
                                                 arguments.alpha);
        beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue,
                                                arguments.beta);

        // Create convolution descriptor object
        if (failed(cudnnCreateConvolutionDescriptor(&conv_desc))) {
            return;
        }

        // Configure convolution operator (one entry per spatial dim: d, h, w)
        int const padding[] = {configuration.problem_size.pad_d,
                               configuration.problem_size.pad_h,
                               configuration.problem_size.pad_w};
        int const stride[] = {configuration.problem_size.stride_d,
                              configuration.problem_size.stride_h,
                              configuration.problem_size.stride_w};
        int const dilation[] = {configuration.problem_size.dilation_d,
                                configuration.problem_size.dilation_h,
                                configuration.problem_size.dilation_w};

        if (failed(cudnnSetConvolutionNdDescriptor(
                    conv_desc, op_desc.conv_dim, padding, stride, dilation,
                    conv_mode, compute_type))) {
            return;
        }

        // Set groups
        if (failed(cudnnSetConvolutionGroupCount(
                    conv_desc, configuration.problem_size.groups))) {
            return;
        }

        // Create activation, filter, and output descriptor objects
        if (failed(cudnnCreateTensorDescriptor(&activation_desc))) {
            return;
        }
        if (failed(cudnnCreateFilterDescriptor(&filter_desc))) {
            return;
        }
        if (failed(cudnnCreateTensorDescriptor(&output_desc))) {
            return;
        }

        // Set activation descriptor. Extents are listed as N, C, D, H, W; the
        // strides are given in the same order, taking layout stride [3] for N,
        // a literal 1 for C, and layout strides [2], [1], [0] for D, H, W.
        int const activation_extent[] = {
                configuration.problem_size.N, configuration.problem_size.C,
                configuration.problem_size.D, configuration.problem_size.H,
                configuration.problem_size.W};

        int const activation_stride[] = {
                configuration.layout_activations.stride()[3], 1,
                configuration.layout_activations.stride()[2],
                configuration.layout_activations.stride()[1],
                configuration.layout_activations.stride()[0]};

        if (failed(cudnnSetTensorNdDescriptor(
                    activation_desc, data_type_activation, op_desc.conv_dim + 2,
                    activation_extent, activation_stride))) {
            return;
        }

        // Set filter descriptor (K, C, T, R, S). The filter descriptor call
        // takes a layout enumeration and extents only, so no stride array is
        // built for it.
        int const filter_extent[] = {
                configuration.problem_size.K, configuration.problem_size.C,
                configuration.problem_size.T, configuration.problem_size.R,
                configuration.problem_size.S};

        if (failed(cudnnSetFilterNdDescriptor(
                    filter_desc, data_type_filter, layout_filter,
                    op_desc.conv_dim + 2, filter_extent))) {
            return;
        }

        // Set output descriptor (N, K, Z, P, Q), strides ordered like the
        // activation strides above
        int const output_extent[] = {
                configuration.problem_size.N, configuration.problem_size.K,
                configuration.problem_size.Z, configuration.problem_size.P,
                configuration.problem_size.Q};

        int const output_stride[] = {configuration.layout_output.stride()[3],
                                     1,
                                     configuration.layout_output.stride()[2],
                                     configuration.layout_output.stride()[1],
                                     configuration.layout_output.stride()[0]};

        if (failed(cudnnSetTensorNdDescriptor(
                    output_desc, data_type_output, op_desc.conv_dim + 2,
                    output_extent, output_stride))) {
            return;
        }

        // Set math instruction to tensor op
        if (failed(cudnnSetConvolutionMathType(conv_desc, math_type))) {
            return;
        }

        // Initialize workspace: query the size required by the fixed algorithm
        // of the requested convolution kind
        switch (conv_kind) {
            case library::ConvKind::kFprop:
                status = get_cutlass_status(
                        cudnnGetConvolutionForwardWorkspaceSize(
                                handle, activation_desc, filter_desc, conv_desc,
                                output_desc, fprop_algo,
                                &workspace_size_in_bytes));
                break;
            case library::ConvKind::kDgrad:
                status = get_cutlass_status(
                        cudnnGetConvolutionBackwardDataWorkspaceSize(
                                handle, filter_desc, output_desc, conv_desc,
                                activation_desc, dgrad_algo,
                                &workspace_size_in_bytes));
                break;
            case library::ConvKind::kWgrad:
                status = get_cutlass_status(
                        cudnnGetConvolutionBackwardFilterWorkspaceSize(
                                handle, activation_desc, output_desc, conv_desc,
                                filter_desc, wgrad_algo,
                                &workspace_size_in_bytes));
                break;
            default:
                // operator() cannot execute any other kind either
                status = Status::kErrorNotSupported;
                break;
        }

        if (status != Status::kSuccess) {
            return;
        }

        workspace = cutlass::device_memory::allocation<char>(
                workspace_size_in_bytes);
    }

    /// Executes the convolution selected by `conv_kind` through the cudnn
    /// library, using the descriptors, scalars and workspace prepared by the
    /// constructor. The result is always written to `arguments.D`.
    ///
    /// Returns the cudnnStatus_t of the cudnn call; throws std::runtime_error
    /// when `conv_kind` is not fprop, dgrad or wgrad.
    cudnnStatus_t operator()(cudnnHandle_t handle) {
        if (conv_kind == library::ConvKind::kFprop) {
            // forward: activation * filter -> output
            void const* x = activation();
            void const* w = filter();
            return cudnnConvolutionForward(
                    handle, &alpha, activation_desc, x, filter_desc, w,
                    conv_desc, fprop_algo, workspace.get(),
                    workspace_size_in_bytes, &beta, output_desc, arguments.D);
        }

        if (conv_kind == library::ConvKind::kDgrad) {
            // backward data: filter, output -> activation-shaped result
            void const* w = filter();
            void const* dy = output();
            return cudnnConvolutionBackwardData(
                    handle, &alpha, filter_desc, w, output_desc, dy, conv_desc,
                    dgrad_algo, workspace.get(), workspace_size_in_bytes, &beta,
                    activation_desc, arguments.D);
        }

        if (conv_kind == library::ConvKind::kWgrad) {
            // backward filter: activation, output -> filter-shaped result
            void const* x = activation();
            void const* dy = output();
            return cudnnConvolutionBackwardFilter(
                    handle, &alpha, activation_desc, x, output_desc, dy,
                    conv_desc, wgrad_algo, workspace.get(),
                    workspace_size_in_bytes, &beta, filter_desc, arguments.D);
        }

        throw std::runtime_error(
                "Invalid Conv Operator (fprop, dgrad, wgrad)");
    }

    // Returns the activation tensor pointer: operand A for fprop, C for
    // dgrad, B for wgrad. Throws std::runtime_error for any other conv kind.
    void const* activation() const {
        if (conv_kind == library::ConvKind::kFprop) {
            return arguments.A;
        }
        if (conv_kind == library::ConvKind::kDgrad) {
            return arguments.C;
        }
        if (conv_kind == library::ConvKind::kWgrad) {
            return arguments.B;
        }
        throw std::runtime_error(
                "Invalid Conv Operator (fprop, dgrad, wgrad)");
    }

    // Returns the filter tensor pointer: operand B for fprop and dgrad, C for
    // wgrad. Throws std::runtime_error for any other conv kind.
    void const* filter() const {
        bool const filter_is_b = conv_kind == library::ConvKind::kFprop ||
                                 conv_kind == library::ConvKind::kDgrad;
        if (filter_is_b) {
            return arguments.B;
        }
        if (conv_kind == library::ConvKind::kWgrad) {
            return arguments.C;
        }
        throw std::runtime_error(
                "Invalid Conv Operator (fprop, dgrad, wgrad)");
    }

    // Returns the output tensor pointer: operand C for fprop, A for dgrad and
    // wgrad. Throws std::runtime_error for any other conv kind.
    void const* output() const {
        if (conv_kind == library::ConvKind::kFprop) {
            return arguments.C;
        }
        bool const output_is_a = conv_kind == library::ConvKind::kDgrad ||
                                 conv_kind == library::ConvKind::kWgrad;
        if (output_is_a) {
            return arguments.A;
        }
        throw std::runtime_error(
                "Invalid Conv Operator (fprop, dgrad, wgrad)");
    }
};

}  // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
}  // namespace profiler
}  // namespace cutlass
// The guard is closed only after both namespaces, mirroring the fact that it
// is opened before they are: with CUTLASS_ENABLE_CUDNN off the header must
// expand to nothing instead of two unmatched closing braces.
#endif  // #if CUTLASS_ENABLE_CUDNN
